from torchtext.data import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from typing import List, Iterable
import unicodedata
from torch.utils.data import Dataset, DataLoader
from config import *
import opencc

# OpenCC converter: normalizes Traditional Chinese input to Simplified ('t2s').
cc = opencc.OpenCC('t2s')


# Strip combining marks so accented Latin text collapses to plain characters,
# which simplifies downstream tokenization.
def unicodeToAscii(s):
    """Return *s* NFD-decomposed with all combining marks (category 'Mn') removed."""
    decomposed = unicodedata.normalize('NFD', s)
    kept = [ch for ch in decomposed if unicodedata.category(ch) != 'Mn']
    return ''.join(kept)


def filterPair(p):
    """Keep a (src_tokens, tgt_tokens) pair only if both sides are shorter than MAX_LENGTH."""
    src_ok = len(p[0]) < MAX_LENGTH
    tgt_ok = len(p[1]) < MAX_LENGTH
    return src_ok and tgt_ok


# eng_prefixes = (
#     "i am ", "i'm ",
#     "he is", "he's ",
#     "she is", "she's ",
#     "you are", "you're ",
#     "we are", "we're ",
#     "they are", "they're "
# )

# spaCy-backed tokenizers (each maps a sentence string to a list of token strings).
# Requires the 'en_core_web_sm' and 'zh_core_web_sm' spaCy models to be installed.
token_transform_en = get_tokenizer('spacy', language='en_core_web_sm')
token_transform_zh = get_tokenizer('spacy', language='zh_core_web_sm')



def ldata(path=DATA_ROOT + 'eng-cmn.txt', num_examples=10000):
    """Load up to *num_examples* tab-separated (English, Chinese) sentence pairs.

    Each line is converted to Simplified Chinese, split on the tab, then the
    English side is lowercased, stripped, and accent-normalized before spaCy
    tokenization. Pairs where either side reaches MAX_LENGTH tokens are
    dropped (see filterPair), so fewer than num_examples pairs may be returned.

    Returns:
        (src, tgt): parallel lists of token lists, e.g. [['hi', '.'], ...].
    """
    src = []  # English token lists
    tgt = []  # Chinese token lists
    # 'with' guarantees the file handle is closed; the original leaked it.
    with open(path, encoding='utf-8') as f:
        for i, line in enumerate(f):
            if i >= num_examples:  # scan at most num_examples lines
                break
            pair = cc.convert(line).split('\t')
            # English side: lowercase + strip + drop accents, then tokenize.
            pair[0] = token_transform_en(unicodeToAscii(pair[0].lower().strip()))
            pair[1] = token_transform_zh(unicodeToAscii(pair[1].strip()))
            if filterPair(pair):
                src.append(pair[0])
                tgt.append(pair[1])
    return src, tgt


class MyDataset(Dataset):
    """Paired source/target NMT dataset of fixed-length index tensors.

    Token lists such as [['how', 'are', 'you', '?'], ...] are converted once,
    up front, into padded index arrays of length num_steps (e.g.
    [[8, 9, 10, 11, 3, 1, 1, ...], ...]); __getitem__ is then pure indexing.
    """

    def __init__(self, src, tgt, src_vocab, tgt_vocab, num_steps):
        self.src_array, self.src_valid_len = build_array_nmt(src, src_vocab, num_steps)
        self.tgt_array, self.tgt_valid_len = build_array_nmt(tgt, tgt_vocab, num_steps)

    def __getitem__(self, idx):
        """Return (src, src_valid_len, tgt, tgt_valid_len) for one example."""
        fields = (self.src_array, self.src_valid_len,
                  self.tgt_array, self.tgt_valid_len)
        return tuple(field[idx] for field in fields)

    def __len__(self):
        return len(self.src_array)


# Adapter exposing a sequence of token lists as the iterator that
# build_vocab_from_iterator expects.
def yield_tokens(data_iter):
    yield from data_iter


def truncate_pad(line, num_steps, padding_token):
    """Force *line* to exactly num_steps tokens: clip long lines, right-pad short ones."""
    clipped = line[:num_steps]
    return clipped + [padding_token] * (num_steps - len(clipped))


def build_array_nmt(lines, vocab, num_steps):
    """Convert tokenized text lines into a padded index tensor plus valid lengths.

    Each line is mapped through *vocab*, gets an '<eos>' index appended, and is
    truncated or right-padded with '<pad>' to exactly num_steps positions.
    valid_len counts the non-'<pad>' entries in each row.
    """
    eos, pad = vocab['<eos>'], vocab['<pad>']  # hoist invariant lookups
    rows = []
    for line in lines:
        seq = [vocab[token] for token in line] + [eos]
        if len(seq) > num_steps:
            seq = seq[:num_steps]
        else:
            seq = seq + [pad] * (num_steps - len(seq))
        rows.append(seq)
    array = torch.tensor(rows)
    valid_len = (array != pad).type(torch.int32).sum(1)
    return array, valid_len


def predata(num_steps, num_examples=20000):
    """Load sentence pairs, build both vocabularies, and assemble the dataset.

    Returns:
        (dataset, src_vocab, tgt_vocab).
    """
    src, tgt = ldata(num_examples=num_examples)
    vocabs = []
    for side in (src, tgt):
        vocab = build_vocab_from_iterator(
            yield_tokens(side), min_freq=1, specials=special_symbols)
        vocab.set_default_index(0)  # index 0 is '<unk>': OOV tokens map there
        vocabs.append(vocab)
    src_vocab, tgt_vocab = vocabs
    dataset = MyDataset(src, tgt, src_vocab, tgt_vocab, num_steps)
    return dataset, src_vocab, tgt_vocab


def loaddata(dataset, batch_size, is_train=True):
    """Wrap *dataset* in a DataLoader, shuffling only during training."""
    shuffle = bool(is_train)
    return DataLoader(dataset, batch_size, shuffle=shuffle)


if __name__ == "__main__":
    # num_examples was previously unbound here (NameError unless config
    # happened to export it), and could disagree with the hard-coded 20000
    # passed to predata; bind it once and use it everywhere.
    num_examples = 20000
    dataset, eng, zh = predata(10, num_examples)
    # Cache dataset and vocabs so training scripts can torch.load them.
    torch.save(dataset, DATA_ROOT + "dataset{}".format(num_examples))
    torch.save(eng, DATA_ROOT + "eng_vocab{}".format(num_examples))
    torch.save(zh, DATA_ROOT + "zh_vocab{}".format(num_examples))
